¿Cuál es la forma más fácil de transformar el tensor de forma (tamaño_de_lote, alto, ancho) relleno con n valores en un tensor de forma (tamaño_de_lote, n, alto, ancho)? Creé la solución a continuación, pero parece que hay una forma más fácil y rápida de hacer esto def batch_tensor_to_onehot (tnsr, clases): tnsr = tnsr.unsqueeze (1) res = [] para cls en rango (clases): res.append ((tnsr == cls) .long ()) return antorcha.cat (res, dim = 1)
2021-02-20 08:20:19
Puede utilizar torch.nn.functional.one_hot. Para tu caso: a = antorcha.nn.functional.one_hot (tnsr, num_classes = clases) out = a.permute (0, 3, 1, 2) | También puede usar Tensor.scatter_, que evita .permute, pero posiblemente sea más difícil de entender que el método sencillo propuesto por @Alpha. def batch_tensor_to_onehot (tnsr, clases): resultado = antorcha.zeros (tnsr.shape [0], clases, * tnsr.shape [1:], dtype = antorcha.long, dispositivo = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) devolver resultado Resultados de la evaluación comparativa Tenía curiosidad y decidí comparar los tres enfoques. Descubrí que no parece haber una diferencia relativa significativa entre los métodos propuestos con respecto al tamaño del lote, el ancho o la altura. Principalmente, el número de clases fue el factor distintivo. Por supuesto, como con cualquier referencia, el kilometraje puede variar. Los puntos de referencia se recopilaron utilizando índices aleatorios y utilizando el tamaño de lote, altura, ancho = 100. Cada experimento se repitió 20 veces y se informó el promedio. El experimento num_classes = 100 se ejecuta una vez antes de generar perfiles para el calentamiento. Los resultados de la CPU muestran que el método original probablemente fue mejor para num_classes menos de aproximadamente 30, mientras que para GPU el enfoque scatter_ parece ser el más rápido. Pruebas realizadas en Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K El código utilizado para la evaluación comparativa se proporciona a continuación: antorcha de importación desde tqdm import tqdm tiempo de importación importar matplotlib.pyplot como plt def batch_tensor_to_onehot_slavka (tnsr, classes): tnsr = tnsr.unsqueeze (1) res = [] para cls en rango (clases): res.append ((tnsr == cls) .long ()) return antorcha.cat (res, dim = 1) def batch_tensor_to_onehot_alpha (tnsr, clases): resultado = antorcha.nn.functional.one_hot (tnsr, num_classes = clases) return result.permute (0, 3, 1, 2) def batch_tensor_to_onehot_jodag (tnsr, classes): resultado = torch.zeros (tnsr.shape [0], clases, * tnsr.shape [1:], dtype = torch.long, dispositivo = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) devolver resultado def main (): num_classes = [2, 10, 25, 50, 100] altura = 100 ancho = 100 bs = [100] * 20 para d en ['cpu', 'cuda']: times_slavka = [] times_alpha = [] times_jodag = [] calentamiento = Verdadero para c en tqdm ([num_classes [-1]] + num_classes, ncols = 0): tslavka = 0 talpha = 0 tjodag = 0 para b en bs: tnsr = antorcha.randint (c, (b, altura, ancho)). a (dispositivo = d) t0 = tiempo.tiempo () y = batch_tensor_to_onehot_slavka (tnsr, c) torch.cuda.synchronize () tslavka + = tiempo.tiempo () - t0 si no es calentamiento: times_slavka.append (tslavka / len (bs)) para b en bs: tnsr = antorcha.randint (c, (b, altura, ancho)). a (dispositivo = d) t0 = tiempo.tiempo () y = batch_tensor_to_onehot_alpha (tnsr, c) torch.cuda.synchronize () talpha + = tiempo.tiempo () - t0 si no es calentamiento: times_alpha.append (talpha / len (bs)) para b en bs: tnsr = antorcha.randint (c, (b, altura, ancho)). a (dispositivo = d) t0 = tiempo.tiempo () y = batch_tensor_to_onehot_jodag (tnsr, c) torch.cuda.synchronize () tjodag + = tiempo.tiempo () - t0 si no es calentamiento: times_jodag.append (tjodag / len (bs)) warmup = Falso fig = plt.figure () ax = fig.subplots () ax.plot (num_classes, times_slavka, label = 'Slavka-cat') ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot') ax.plot (num_classes, times_jodag, label = 'jodag-scatter_') ax.set_xlabel ('num_classes') ax.set_ylabel ('tiempo (s)') ax.set_title (f '{d} punto de referencia') ax.legend () plt.savefig (f '{d} .png') plt.show () if __name__ == "__main__": principal() | Tu respuesta StackExchange.ifUsing ("editor", function () { StackExchange.using ("editor externo", función () { StackExchange.using ("fragmentos", función () { StackExchange.snippets.init (); }); }); }, "fragmentos de código"); StackExchange.ready (function () { var channelOptions = { etiquetas: "" .split (""), id: "1" }; initTagRenderer ("". split (""), "" .split (""), channelOptions); StackExchange.using ("editor externo", función () { // Debe activar el editor después de los fragmentos, si los fragmentos están habilitados if (StackExchange.settings.snippets.snippetsEnabled) { StackExchange.using ("fragmentos", función () { createEditor (); }); } demás { createEditor (); } }); function createEditor () { StackExchange.prepareEditor ({ useStacksEditor: false, heartbeatType: 'respuesta', autoActivateHeartbeat: falso, convertImagesToLinks: verdadero, noModals: cierto, showLowRepImageUploadWarning: true, reputacionToPostImages: 10, bindNavPrevention: verdadero, sufijo: "", imageUploader: { brandingHtml: "Desarrollado por \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46.2665 7.94324 47.1084 7.58816C47.4091 7.46349 47.7169 7.36433 48.0099 7.26993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.1414 4.61182C47.4335 4.61182 46.725645.40531 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C41.5986 5.28832 4.6 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 34.0034 4.66232C32 fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9615 5.28821 30 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.913 25.3752.416 13.713.4913.290 C28. 1256 12.8854 28.1301 12.9342 28.1301 27.2502 15.2321 14.4373 12.983C28.1301 25.777 15.2321C24.8349 24.1352 14.9821 23.5661 15.2321 14.7787C23.176 22.8472 14.5218 22.5437 14.6393 14.5218C21.7977 21.2429 15.0123 21.2429 14.5218 15.6887C21.2429 22,9072 17,6335 25,6622 16,7375 9,27932 17.6335ZM24.1317 C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.279323Ze \ "/ 3. 8045 13.2535 17.2637 13.8962 18.2965 13.8962C19.3298 13.8962 19.8079 13.2535 19.8079 11.9512V8.12928C19.8079 5.82936 18.4879 4.62866 16.4027 4.62866C15.1594 4.62866 14.279 4.98375 13.3609 4.88013C12.6531 5.07.98375 13.3609 5.88013C12.65831 7,57.980 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.4375 13.87962. C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821 0.31335512C9049.3913 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.822846 3.57676 1.8761C0 2.87869 0.822846 3.57676 1.8761C0 3.5006 3.000 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e", contentPolicyHtml: "Contribuciones de usuario con licencia bajo \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (política de contenido) \ u003c / a \ u003e", allowUrls: verdadero }, onDemand: verdadero, discardSelector: ".discard-answer" , inmediatamenteShowMarkdownHelp: true, enableTables: true, enableSnippets: true }); } }); ¡Gracias por contribuir con una respuesta a Stack Overflow! Asegúrese de responder la pregunta. ¡Proporcione detalles y comparta su investigación! Pero evita ... Pedir ayuda, aclaraciones o responder a otras respuestas. Hacer declaraciones basadas en opiniones; respóndelos con referencias o experiencia personal. Para obtener más información, consulte nuestros consejos sobre cómo escribir buenas respuestas. Borrador guardado Borrador descartado Regístrate o inicia sesión StackExchange.ready (function () { StackExchange.helpers.onClickDraftSave ('# login-link'); }); Regístrese con Google Registrarse usando Facebook Regístrese con correo electrónico y contraseña Enviar Publicar como invitado Nombre Correo electrónico Requerido, pero nunca mostrado StackExchange.ready ( function () { StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' ); } ); Publicar como invitado Nombre Correo electrónico Requerido, pero nunca mostrado Publica tu respuesta Descarte Al hacer clic en "Publicar su respuesta", acepta nuestros términos de servicio, política de privacidad y política de cookies. No es la respuesta que estás buscando? Lea otras preguntas con la etiqueta python pytorch tensor one-hot-encoding o haga su propia pregunta.